{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b518b04cbfe0"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "906e07f6e562"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "daf323e33b84"
      },
      "source": [
        "# 从头编写训练循环"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2440f6e0c5ef"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>     <a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看</a>\n",
        "</td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行</a>\n",
        "</td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 上查看源代码</a>\n",
        "</td>\n",
        "  <td>     <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/keras/writing_a_training_loop_from_scratch.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">下载笔记本</a>\n",
        "</td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8d4ac441b1fc"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ae2407ad926f"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0f5a253901f8"
      },
      "source": [
        "## 简介\n",
        "\n",
        "Keras 提供了默认的训练与评估循环 `fit()` 和 `evaluate()`。[使用内置方法进行训练和评估](https://tensorflow.google.cn/guide/keras/train_and_evaluate/)指南中介绍了它们的用法。\n",
        "\n",
        "如果想要自定义模型的学习算法，同时又能利用 `fit()` 的便利性（例如，使用 `fit()` 训练 GAN），则可以将 `Model` 类子类化并实现自己的 `train_step()` 方法，此方法可在 `fit()` 中重复调用。<a href=\"https://tensorflow.google.cn/guide/keras/customizing_what_happens_in_fit/\" data-md-type=\"link\">自定义 `fit()` 的功能</a>指南对此进行了介绍。\n",
        "\n",
        "现在，如果您想对训练和评估进行低级别控制，则应当从头开始编写自己的训练和评估循环。这正是本指南要介绍的内容。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f4f47351a3ec"
      },
      "source": [
        "## 使用 `GradientTape`：第一个端到端示例\n",
        "\n",
        "在 `GradientTape` 作用域内调用模型使您可以检索层的可训练权重相对于损失值的梯度。利用优化器实例，您可以使用上述梯度来更新这些变量（可以使用 `model.trainable_weights` 检索这些变量）。\n",
        "\n",
        "我们考虑一个简单的 MNIST 模型："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aaa775ce7dab"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
        "x1 = layers.Dense(64, activation=\"relu\")(inputs)\n",
        "x2 = layers.Dense(64, activation=\"relu\")(x1)\n",
        "outputs = layers.Dense(10, name=\"predictions\")(x2)\n",
        "model = keras.Model(inputs=inputs, outputs=outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d8b02a5759cf"
      },
      "source": [
        "我们使用带自定义训练循环的 mini-batch 梯度对其进行训练。\n",
        "\n",
        "首先，我们需要优化器、损失函数和数据集："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f2c6257b8d02"
      },
      "outputs": [],
      "source": [
        "# Instantiate an optimizer.\n",
        "optimizer = keras.optimizers.SGD(learning_rate=1e-3)\n",
        "# Instantiate a loss function.\n",
        "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "# Prepare the training dataset.\n",
        "batch_size = 64\n",
        "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
        "x_train = np.reshape(x_train, (-1, 784))\n",
        "x_test = np.reshape(x_test, (-1, 784))\n",
        "\n",
        "# Reserve 10,000 samples for validation.\n",
        "x_val = x_train[-10000:]\n",
        "y_val = y_train[-10000:]\n",
        "x_train = x_train[:-10000]\n",
        "y_train = y_train[:-10000]\n",
        "\n",
        "# Prepare the training dataset.\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
        "\n",
        "# Prepare the validation dataset.\n",
        "val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n",
        "val_dataset = val_dataset.batch(batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5c30285b1a2e"
      },
      "source": [
        "下面是我们的训练循环：\n",
        "\n",
        "- 我们打开一个遍历各周期的 `for` 循环\n",
        "- 对于每个周期，我们打开一个分批遍历数据集的 `for` 循环\n",
        "- 对于每个批次，我们打开一个 `GradientTape()` 作用域\n",
        "- 在此作用域内，我们调用模型（前向传递）并计算损失\n",
        "- 在作用域之外，我们检索模型权重相对于损失的梯度\n",
        "- 最后，我们根据梯度使用优化器来更新模型的权重"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5bf4c10ceb50"
      },
      "outputs": [],
      "source": [
        "epochs = 2\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
        "\n",
        "        # Open a GradientTape to record the operations run\n",
        "        # during the forward pass, which enables auto-differentiation.\n",
        "        with tf.GradientTape() as tape:\n",
        "\n",
        "            # Run the forward pass of the layer.\n",
        "            # The operations that the layer applies\n",
        "            # to its inputs are going to be recorded\n",
        "            # on the GradientTape.\n",
        "            logits = model(x_batch_train, training=True)  # Logits for this minibatch\n",
        "\n",
        "            # Compute the loss value for this minibatch.\n",
        "            loss_value = loss_fn(y_batch_train, logits)\n",
        "\n",
        "        # Use the gradient tape to automatically retrieve\n",
        "        # the gradients of the trainable variables with respect to the loss.\n",
        "        grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "\n",
        "        # Run one step of gradient descent by updating\n",
        "        # the value of the variables to minimize the loss.\n",
        "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "\n",
        "        # Log every 200 batches.\n",
        "        if step % 200 == 0:\n",
        "            print(\n",
        "                \"Training loss (for one batch) at step %d: %.4f\"\n",
        "                % (step, float(loss_value))\n",
        "            )\n",
        "            print(\"Seen so far: %s samples\" % ((step + 1) * batch_size))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d600076b7be0"
      },
      "source": [
        "## 指标的低级处理\n",
        "\n",
        "我们在此基本循环中添加指标监视。\n",
        "\n",
        "在这种从头开始编写的训练循环中，您可以轻松重用内置指标（或编写的自定义指标）。下面列出了具体流程：\n",
        "\n",
        "- 在循环开始时实例化指标\n",
        "- 在每个批次后调用 `metric.update_state()`\n",
        "- 当您需要显示指标的当前值时，调用 `metric.result()`\n",
        "- 当您需要清除指标的状态（通常在周期结束）时，调用 `metric.reset_states()`\n",
        "\n",
        "我们利用这些知识在每个周期结束时基于验证数据计算 `SparseCategoricalAccuracy`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2602509b16c7"
      },
      "outputs": [],
      "source": [
        "# Get model\n",
        "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
        "x = layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n",
        "x = layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n",
        "outputs = layers.Dense(10, name=\"predictions\")(x)\n",
        "model = keras.Model(inputs=inputs, outputs=outputs)\n",
        "\n",
        "# Instantiate an optimizer to train the model.\n",
        "optimizer = keras.optimizers.SGD(learning_rate=1e-3)\n",
        "# Instantiate a loss function.\n",
        "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "# Prepare the metrics.\n",
        "train_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n",
        "val_acc_metric = keras.metrics.SparseCategoricalAccuracy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9111a5cc87dc"
      },
      "source": [
        "下面是我们的训练和评估循环："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "654e2311dbff"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "epochs = 2\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "    start_time = time.time()\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
        "        with tf.GradientTape() as tape:\n",
        "            logits = model(x_batch_train, training=True)\n",
        "            loss_value = loss_fn(y_batch_train, logits)\n",
        "        grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "\n",
        "        # Update training metric.\n",
        "        train_acc_metric.update_state(y_batch_train, logits)\n",
        "\n",
        "        # Log every 200 batches.\n",
        "        if step % 200 == 0:\n",
        "            print(\n",
        "                \"Training loss (for one batch) at step %d: %.4f\"\n",
        "                % (step, float(loss_value))\n",
        "            )\n",
        "            print(\"Seen so far: %d samples\" % ((step + 1) * batch_size))\n",
        "\n",
        "    # Display metrics at the end of each epoch.\n",
        "    train_acc = train_acc_metric.result()\n",
        "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
        "\n",
        "    # Reset training metrics at the end of each epoch\n",
        "    train_acc_metric.reset_states()\n",
        "\n",
        "    # Run a validation loop at the end of each epoch.\n",
        "    for x_batch_val, y_batch_val in val_dataset:\n",
        "        val_logits = model(x_batch_val, training=False)\n",
        "        # Update val metrics\n",
        "        val_acc_metric.update_state(y_batch_val, val_logits)\n",
        "    val_acc = val_acc_metric.result()\n",
        "    val_acc_metric.reset_states()\n",
        "    print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
        "    print(\"Time taken: %.2fs\" % (time.time() - start_time))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "940d8d9fae83"
      },
      "source": [
        "## 使用 `tf.function` 加快训练步骤的速度\n",
        "\n",
        "TensorFlow 2.0 中的默认运行时为 [Eager Execution](https://tensorflow.google.cn/guide/eager)。因此，上面的训练循环会以 Eager 模式执行。\n",
        "\n",
        "这对于调试非常有用，但计算图编译具有确定的性能优势。将您的计算描述为静态计算图可以使框架应用全局性能优化。当框架受约束而以贪心方式一个接一个地执行运算，而又不知道接下来会发生什么时，便无法做到这一点。\n",
        "\n",
        "以张量为输入的任何函数都可以编译为静态计算图。只需添加一个 `@tf.function` 装饰器，具体如下所示："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fdacc2d48ade"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(x, y):\n",
        "    with tf.GradientTape() as tape:\n",
        "        logits = model(x, training=True)\n",
        "        loss_value = loss_fn(y, logits)\n",
        "    grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "    train_acc_metric.update_state(y, logits)\n",
        "    return loss_value\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ab61b0bf3126"
      },
      "source": [
        "我们对评估步骤执行相同的操作："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "da4828fd8ef7"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def test_step(x, y):\n",
        "    val_logits = model(x, training=False)\n",
        "    val_acc_metric.update_state(y, val_logits)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d552377968f1"
      },
      "source": [
        "现在，我们使用编译后的训练步骤重新运行训练循环："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d69d73c94e44"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "\n",
        "epochs = 2\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart of epoch %d\" % (epoch,))\n",
        "    start_time = time.time()\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
        "        loss_value = train_step(x_batch_train, y_batch_train)\n",
        "\n",
        "        # Log every 200 batches.\n",
        "        if step % 200 == 0:\n",
        "            print(\n",
        "                \"Training loss (for one batch) at step %d: %.4f\"\n",
        "                % (step, float(loss_value))\n",
        "            )\n",
        "            print(\"Seen so far: %d samples\" % ((step + 1) * batch_size))\n",
        "\n",
        "    # Display metrics at the end of each epoch.\n",
        "    train_acc = train_acc_metric.result()\n",
        "    print(\"Training acc over epoch: %.4f\" % (float(train_acc),))\n",
        "\n",
        "    # Reset training metrics at the end of each epoch\n",
        "    train_acc_metric.reset_states()\n",
        "\n",
        "    # Run a validation loop at the end of each epoch.\n",
        "    for x_batch_val, y_batch_val in val_dataset:\n",
        "        test_step(x_batch_val, y_batch_val)\n",
        "\n",
        "    val_acc = val_acc_metric.result()\n",
        "    val_acc_metric.reset_states()\n",
        "    print(\"Validation acc: %.4f\" % (float(val_acc),))\n",
        "    print(\"Time taken: %.2fs\" % (time.time() - start_time))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8977d77a8095"
      },
      "source": [
        "速度快了很多，对吗？"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b5b5a54d339a"
      },
      "source": [
        "## 对模型跟踪的损失进行低级处理\n",
        "\n",
        "层和模型以递归方式跟踪调用 `self.add_loss(value)` 的层在前向传递过程中创建的任何损失。可在前向传递结束时通过属性 `model.losses` 获得标量损失值的结果列表。\n",
        "\n",
        "如果要使用这些损失分量，应将它们求和并添加到训练步骤的主要损失中。\n",
        "\n",
        "考虑下面这个层，它会产生活动正则化损失："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4ec7c4b16596"
      },
      "outputs": [],
      "source": [
        "class ActivityRegularizationLayer(layers.Layer):\n",
        "    def call(self, inputs):\n",
        "        self.add_loss(1e-2 * tf.reduce_sum(inputs))\n",
        "        return inputs\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6b12260b8bf2"
      },
      "source": [
        "我们构建一个使用它的超简单模型："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "57afe49e6b93"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(784,), name=\"digits\")\n",
        "x = layers.Dense(64, activation=\"relu\")(inputs)\n",
        "# Insert activity regularization as a layer\n",
        "x = ActivityRegularizationLayer()(x)\n",
        "x = layers.Dense(64, activation=\"relu\")(x)\n",
        "outputs = layers.Dense(10, name=\"predictions\")(x)\n",
        "\n",
        "model = keras.Model(inputs=inputs, outputs=outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aadb58115c13"
      },
      "source": [
        "我们的训练步骤现在应当如下所示："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cf674776a0d2"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(x, y):\n",
        "    with tf.GradientTape() as tape:\n",
        "        logits = model(x, training=True)\n",
        "        loss_value = loss_fn(y, logits)\n",
        "        # Add any extra losses created during the forward pass.\n",
        "        loss_value += sum(model.losses)\n",
        "    grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "    train_acc_metric.update_state(y, logits)\n",
        "    return loss_value\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0af04732fe78"
      },
      "source": [
        "## 总结\n",
        "\n",
        "现在，您已了解如何使用内置训练循环以及从头开始编写自己的训练循环。\n",
        "\n",
        "总之，下面是一个简单的端到端示例，它将您在本指南中学到的所有知识串联起来：一个在 MNIST 数字上训练的 DCGAN。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9fb325331a1e"
      },
      "source": [
        "## 端到端示例：从头开始的 GAN 训练循环\n",
        "\n",
        "您可能熟悉生成对抗网络 (GAN)。通过学习图像训练数据集的隐分布（图像的“隐空间”），GAN 可以生成看起来极为真实的新图像。\n",
        "\n",
        "一个 GAN 由两部分组成：一个“生成器”模型（可将隐空间中的点映射到图像空间中的点）和一个“判别器”模型，后者是一个可以区分真实图像（来自训练数据集）与虚假图像（生成器网络的输出）之间差异的分类器。\n",
        "\n",
        "GAN 训练循环如下所示：\n",
        "\n",
        "1. 训练判别器。\n",
        "\n",
        "- 在隐空间中对一批随机点采样。\n",
        "- 通过“生成器”模型将这些点转换为虚假图像。\n",
        "- 获取一批真实图像，并将它们与生成的图像组合。\n",
        "- 训练“判别器”模型以对生成的图像与真实图像进行分类。\n",
        "\n",
        "1. 训练生成器。\n",
        "\n",
        "- 在隐空间中对随机点采样。\n",
        "- 通过“生成器”网络将这些点转换为虚假图像。\n",
        "- 获取一批真实图像，并将它们与生成的图像组合。\n",
        "- 训练“生成器”模型以“欺骗”判别器，并将虚假图像分类为真实图像。\n",
        "\n",
        "有关 GAN 工作原理的详细介绍，请参阅 [Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python)。\n",
        "\n",
        "我们来实现这个训练循环。首先，创建用于区分虚假数字和真实数字的判别器："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fabf9cef3400"
      },
      "outputs": [],
      "source": [
        "discriminator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(28, 28, 1)),\n",
        "        layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.GlobalMaxPooling2D(),\n",
        "        layers.Dense(1),\n",
        "    ],\n",
        "    name=\"discriminator\",\n",
        ")\n",
        "discriminator.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "73396eb6daf9"
      },
      "source": [
        "接着，我们创建一个生成器网络，它可以将隐向量转换成形状为 `(28, 28, 1)`（表示 MNIST 数字）的输出："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "821d203bfb3e"
      },
      "outputs": [],
      "source": [
        "latent_dim = 128\n",
        "\n",
        "generator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(latent_dim,)),\n",
        "        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n",
        "        layers.Dense(7 * 7 * 128),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Reshape((7, 7, 128)),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n",
        "    ],\n",
        "    name=\"generator\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f0d6d54a78a0"
      },
      "source": [
        "这是关键部分：训练循环。如您所见，训练非常简单。训练步骤函数仅有 17 行代码。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3a11c875142e"
      },
      "outputs": [],
      "source": [
        "# Instantiate one optimizer for the discriminator and another for the generator.\n",
        "d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)\n",
        "g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)\n",
        "\n",
        "# Instantiate a loss function.\n",
        "loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
        "\n",
        "\n",
        "@tf.function\n",
        "def train_step(real_images):\n",
        "    # Sample random points in the latent space\n",
        "    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n",
        "    # Decode them to fake images\n",
        "    generated_images = generator(random_latent_vectors)\n",
        "    # Combine them with real images\n",
        "    combined_images = tf.concat([generated_images, real_images], axis=0)\n",
        "\n",
        "    # Assemble labels discriminating real from fake images\n",
        "    labels = tf.concat(\n",
        "        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0\n",
        "    )\n",
        "    # Add random noise to the labels - important trick!\n",
        "    labels += 0.05 * tf.random.uniform(labels.shape)\n",
        "\n",
        "    # Train the discriminator\n",
        "    with tf.GradientTape() as tape:\n",
        "        predictions = discriminator(combined_images)\n",
        "        d_loss = loss_fn(labels, predictions)\n",
        "    grads = tape.gradient(d_loss, discriminator.trainable_weights)\n",
        "    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))\n",
        "\n",
        "    # Sample random points in the latent space\n",
        "    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n",
        "    # Assemble labels that say \"all real images\"\n",
        "    misleading_labels = tf.zeros((batch_size, 1))\n",
        "\n",
        "    # Train the generator (note that we should *not* update the weights\n",
        "    # of the discriminator)!\n",
        "    with tf.GradientTape() as tape:\n",
        "        predictions = discriminator(generator(random_latent_vectors))\n",
        "        g_loss = loss_fn(misleading_labels, predictions)\n",
        "    grads = tape.gradient(g_loss, generator.trainable_weights)\n",
        "    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))\n",
        "    return d_loss, g_loss, generated_images\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fa6bd6292488"
      },
      "source": [
        "我们通过在各个图像批次上重复调用 `train_step` 来训练 GAN。\n",
        "\n",
        "由于我们的判别器和生成器是卷积神经网络，因此您将在 GPU 上运行此代码。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b6a4e3d42262"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Prepare the dataset. We use both the training & test MNIST digits.\n",
        "batch_size = 64\n",
        "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
        "all_digits = np.concatenate([x_train, x_test])\n",
        "all_digits = all_digits.astype(\"float32\") / 255.0\n",
        "all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n",
        "dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n",
        "dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
        "\n",
        "epochs = 1  # In practice you need at least 20 epochs to generate nice digits.\n",
        "save_dir = \"./\"\n",
        "\n",
        "for epoch in range(epochs):\n",
        "    print(\"\\nStart epoch\", epoch)\n",
        "\n",
        "    for step, real_images in enumerate(dataset):\n",
        "        # Train the discriminator & generator on one batch of real images.\n",
        "        d_loss, g_loss, generated_images = train_step(real_images)\n",
        "\n",
        "        # Logging.\n",
        "        if step % 200 == 0:\n",
        "            # Print metrics\n",
        "            print(\"discriminator loss at step %d: %.2f\" % (step, d_loss))\n",
        "            print(\"adversarial loss at step %d: %.2f\" % (step, g_loss))\n",
        "\n",
        "            # Save one generated image\n",
        "            img = tf.keras.preprocessing.image.array_to_img(\n",
        "                generated_images[0] * 255.0, scale=False\n",
        "            )\n",
        "            img.save(os.path.join(save_dir, \"generated_img\" + str(step) + \".png\"))\n",
        "\n",
        "        # To limit execution time we stop after 10 steps.\n",
        "        # Remove the lines below to actually train the model!\n",
        "        if step > 10:\n",
        "            break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a92959ac630b"
      },
      "source": [
        "就是这样！在 Colab GPU 上进行约 30 秒钟的训练后，您将获得漂亮的虚假 MNIST 数字。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "writing_a_training_loop_from_scratch.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
